{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 将输入转化成TFRecord格式并保存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../datasets/MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "TFRecord文件已保存。\n"
     ]
    }
   ],
   "source": [
    "# 定义函数转化变量类型。\n",
    "def _int64_feature(value):\n",
    "    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))\n",
    "\n",
    "def _bytes_feature(value):\n",
    "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "\n",
    "# 读取mnist数据。\n",
    "mnist = input_data.read_data_sets(\"../datasets/MNIST_data\",dtype=tf.uint8, one_hot=True)\n",
    "images = mnist.train.images\n",
    "labels = mnist.train.labels\n",
    "pixels = images.shape[1]\n",
    "num_examples = mnist.train.num_examples\n",
    "\n",
    "# 输出TFRecord文件的地址。\n",
    "filename = \"Records/output.tfrecords\"\n",
    "writer = tf.python_io.TFRecordWriter(filename)\n",
    "for index in range(num_examples):\n",
    "    image_raw = images[index].tostring()\n",
    "\n",
    "    example = tf.train.Example(features=tf.train.Features(feature={\n",
    "        'pixels': _int64_feature(pixels),\n",
    "        'label': _int64_feature(np.argmax(labels[index])),\n",
    "        'image_raw': _bytes_feature(image_raw)\n",
    "    }))\n",
    "    writer.write(example.SerializeToString())\n",
    "writer.close()\n",
    "print(\"TFRecord文件已保存。\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 读取TFRecord文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7 784\n",
      "3 784\n",
      "4 784\n",
      "6 784\n",
      "1 784\n",
      "8 784\n",
      "1 784\n",
      "0 784\n",
      "9 784\n",
      "8 784\n"
     ]
    }
   ],
   "source": [
    "# 读取文件。\n",
    "reader = tf.TFRecordReader()\n",
    "filename_queue = tf.train.string_input_producer([\"Records/output.tfrecords\"])\n",
    "_,serialized_example = reader.read(filename_queue)\n",
    "\n",
    "# 解析读取的样例。\n",
    "features = tf.parse_single_example(\n",
    "    serialized_example,\n",
    "    features={\n",
    "        'image_raw':tf.FixedLenFeature([],tf.string),\n",
    "        'pixels':tf.FixedLenFeature([],tf.int64),\n",
    "        'label':tf.FixedLenFeature([],tf.int64)\n",
    "    })\n",
    "\n",
    "images = tf.decode_raw(features['image_raw'],tf.uint8)\n",
    "labels = tf.cast(features['label'],tf.int32)\n",
    "pixels = tf.cast(features['pixels'],tf.int32)\n",
    "\n",
    "sess = tf.Session()\n",
    "\n",
    "# 启动多线程处理输入数据。\n",
    "coord = tf.train.Coordinator()\n",
    "threads = tf.train.start_queue_runners(sess=sess,coord=coord)\n",
    "\n",
    "for i in range(10):\n",
    "    image, label, pixel = sess.run([images, labels, pixels])\n",
    "    print(label, pixel)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
